# _*_ coding: utf-8 _*_
# @Date : 2023/3/20 23:43
# @Author : Paul
# @File : regression.py
# @Description :

import pandas as pd
import io
import matplotlib.pyplot as plt
from core.utils.string_utils import StringUtils
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer

from core.algo.base_algo import BaseAlgo
from core.data_source.meta_data_source.meta_data_source import MetaDataSource
from core.utils.data_souce_init_utils import DataSourceInitUtil
from core.utils.date_util import DateUtil


class Regression(BaseAlgo):
    """Regression modelling pipeline.

    Loads a table from a configured data source, applies the preprocessing
    strategy described in ``param`` and returns train/test splits.
    """

    # Imputation strategies keyed by the "preProcessMethodValue" setting.
    _IMPUTER_KWARGS = {
        "mean": {},
        "median": {"strategy": "median"},
        "most_frequent": {"strategy": "most_frequent"},
        "constant_0": {"strategy": "constant", "fill_value": 0},
        "constant_1": {"strategy": "constant", "fill_value": 1},
    }

    def __init__(self,
                 app_name="clusters",
                 data_source_id=None,
                 table_name=None,
                 feature_cols=None,
                 class_col=None,
                 train_size=None,
                 param=None
                 ):
        """
        Initialise the regression task.

        :param app_name: application/run name, also embedded in output file names
        :param data_source_id: id of the data source that holds the training table
        :param table_name: table to read the modelling data from
        :param feature_cols: list of feature column names
        :param class_col: single-element list holding the target column name
        :param train_size: kept for interface compatibility; the effective split
            ratio is read from ``param["trainDataRatio"]`` instead
        :param param: dict of run parameters; must provide "trainDataRatio",
            "standardization" and "preProcessMethodList"
        """
        super(Regression, self).__init__(app_name=app_name)
        self.param = param
        # Start time of the run
        self.start_time = DateUtil.getCurrentDate()
        self.table_name = table_name
        self.feature_cols = feature_cols
        self.class_col = class_col
        self.train_size = train_size
        # Query selects features plus the target column
        self.all_col = self.feature_cols + self.class_col
        # Fitted model (assigned later by the training step)
        self.reg = None
        # Concise data summary (DataFrame.info output)
        self.info = None
        # Descriptive statistics (DataFrame.describe output)
        self.describe = None
        # Path of the prediction-effect plot
        self.regress_pred_image = self.image_path + "regress_pred_" + app_name + "_" + DateUtil.getCurrentDateSimple() + ".png"
        # Metadata database
        self.meta_data_source = MetaDataSource()
        # Data source that holds the training table
        self.data_source = DataSourceInitUtil.getDataBase(self.meta_data_source,
                                                          data_source_id)
        self.labels = None
        # NOTE: the split ratio actually used comes from param, not train_size
        self.train_data_ratio = float(self.param["trainDataRatio"])

    def getModelData(self):
        """
        Load the modelling data, run preprocessing and split it.

        :return: two lists, ``[Xtrain, Ytrain]`` and ``[Xtest, Ytest]``
        """
        data_query_sql = "select {} from {}".format(",".join(self.all_col),
                                                    self.table_name)
        data = self.data_source.queryAll(data_query_sql)
        data = pd.DataFrame(data=data, columns=self.all_col)

        # Capture DataFrame.info() output (it writes to a buffer, not a value)
        buf = io.StringIO()
        data.info(buf=buf)
        self.info = buf.getvalue()

        # Descriptive statistics of the raw data
        self.describe = data.describe()

        self._apply_preprocessing(data)
        return self._split(data)

    def _apply_preprocessing(self, data):
        """Apply the steps from param["preProcessMethodList"] to data in place."""
        # Guard against a missing list (param.get may return None), then drop
        # entries that are null or name no target feature.
        raw_methods = self.param.get("preProcessMethodList") or []
        methods = []
        for method in raw_methods:
            if method is None or method == "null":
                continue
            # NOTE(review): isBlack looks like a typo for isBlank in the
            # project helper -- keeping the call as the project defines it.
            if StringUtils.isBlack(method.get("preProcessFeature")):
                continue
            methods.append(method)
        self.param["preProcessMethodList"] = methods

        for method in methods:
            feature = method.get("preProcessFeature")
            strategy = method.get("preProcessMethod")
            value = method.get("preProcessMethodValue")
            if strategy == "deletena":
                # 1. Drop the column entirely
                data.drop(feature, inplace=True, axis=1)
            elif strategy == "fillna":
                # 2. Impute missing values with the configured strategy
                kwargs = self._IMPUTER_KWARGS.get(value)
                if kwargs is not None:
                    imputer = SimpleImputer(**kwargs)
                    data[feature] = imputer.fit_transform(
                        data[feature].values.reshape(-1, 1))
            elif strategy == "transClassFeature":
                # 3. Encode a categorical column as integer codes
                unique_value = data[feature].unique().tolist()
                data[feature] = data[feature].apply(lambda x: unique_value.index(x))
            elif strategy == "transType":
                # 4. Cast the column dtype
                if value in ("int", "float"):
                    data[feature] = data[feature].astype(value)

    def _split(self, data):
        """Scale features if configured, then split into train/test sets."""
        X = data.iloc[:, data.columns != self.class_col[0]]
        Y = data.iloc[:, data.columns == self.class_col[0]]
        # Optional feature scaling ("non-dimensionalisation") strategy
        standardization = self.param["standardization"]
        if standardization == "MinMaxScaler":
            from sklearn.preprocessing import MinMaxScaler
            X = MinMaxScaler().fit_transform(X)
        elif standardization == "StandardScaler":
            from sklearn.preprocessing import StandardScaler
            X = StandardScaler().fit_transform(X)

        Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, train_size=self.train_data_ratio)
        return [Xtrain, Ytrain], [Xtest, Ytest]


if __name__ == '__main__':
    # param is required: __init__ reads param["trainDataRatio"] and
    # getModelData reads param["standardization"] / param["preProcessMethodList"].
    # Without it the constructor crashes on float(None["trainDataRatio"]).
    demo_param = {
        "trainDataRatio": "0.7",
        "standardization": "",
        # One blank-feature entry: it is filtered out during preprocessing
        # and exercises no imputation step.
        "preProcessMethodList": [{"preProcessFeature": ""}],
    }
    regression = Regression(app_name="cluster_demo",
                            data_source_id=9,
                            table_name="titanic",
                            feature_cols=["Survived", "Pclass", "Sex", "Age", "Cabin"],
                            class_col=["Embarked"],
                            train_size=0.7,
                            param=demo_param)
    regression.getModelData()